import random
import sys
from collections import defaultdict
from typing import DefaultDict, Dict, List

import grpc
import httpx
import pytest
from fastapi import FastAPI

import ray
from ray import serve
from ray._common.test_utils import SignalActor, wait_for_condition
from ray.serve._private.constants import DEFAULT_LATENCY_BUCKET_MS
from ray.serve._private.test_utils import (
    get_application_url,
    ping_fruit_stand,
    ping_grpc_call_method,
)
from ray.serve.handle import DeploymentHandle
from ray.serve.metrics import Counter, Gauge, Histogram
from ray.serve.tests.test_config_files.grpc_deployment import g, g2
from ray.serve.tests.test_metrics import (
    check_metric_float_eq,
    check_sum_metric_eq,
    get_metric_dictionaries,
)


@serve.deployment
class WaitForSignal:
    """Deployment that awaits ``wait()`` on the named signal actor per request."""

    async def __call__(self):
        # The actor is looked up by name on every request; tests in this file
        # create it under the name "signal123" before sending traffic.
        signal_actor = ray.get_actor("signal123")
        await signal_actor.wait.remote()


@serve.deployment
class Router:
    """Deployment that forwards each request to one of its downstream handles."""

    def __init__(self, handles):
        # Sequence of handles; requests pick one of them by position.
        self.handles = handles

    async def __call__(self, index: int):
        # ``index`` is 1-based: 1 selects the first handle.
        target = self.handles[index - 1]
        return await target.remote()


@ray.remote
def call(deployment_name, app_name, *args):
    """Send one request to a deployment from a Ray task, ignoring the response.

    Args:
        deployment_name: Name of the target deployment.
        app_name: Name of the application the deployment belongs to.
        *args: Positional arguments forwarded to the deployment.
    """
    deployment_handle = DeploymentHandle(deployment_name, app_name)
    # The response object is intentionally dropped; callers only care that
    # the request was issued.
    deployment_handle.remote(*args)


@ray.remote
class CallActor:
    """Ray actor that sends requests to one deployment through a single handle."""

    def __init__(self, deployment_name: str, app_name: str):
        # One handle is created up front and reused for every call.
        self.handle = DeploymentHandle(deployment_name, app_name)

    async def call(self, *args):
        """Send a request with ``args`` and wait for it to complete."""
        response = self.handle.remote(*args)
        await response


class TestRequestContextMetrics:
    def _generate_metrics_summary(self, metrics: List[Dict]):
        """Generate "route" and "application" information from metrics.

        Args:
            metrics: List of metric dictionaries, each generated by the
                get_metric_dictionaries function.
        Returns:
            Tuple[dict, dict]:
                - The first dictionary maps deployment names to a set of routes.
                - The second dictionary maps deployment names to application names.
        """
        metrics_summary_route = DefaultDict(set)
        metrics_summary_app = DefaultDict(str)

        for request_metrics in metrics:
            metrics_summary_route[request_metrics["deployment"]].add(
                request_metrics["route"]
            )
            metrics_summary_app[request_metrics["deployment"]] = request_metrics[
                "application"
            ]
        return metrics_summary_route, metrics_summary_app

    def verify_metrics(self, metric, expected_output):
        for key in expected_output:
            assert metric[key] == expected_output[key]

    def test_request_context_pass_for_http_proxy(self, metrics_start_shutdown):
        """Verify that the HTTP proxy forwards route and app name to metrics.

        Three single-deployment applications are deployed (the third one
        always raises), each is requested once over HTTP, and the "route" and
        "application" tags are then checked on replica, proxy, handle and
        router metrics.
        """

        @serve.deployment(graceful_shutdown_timeout_s=0.001)
        def f():
            return "hello"

        @serve.deployment(graceful_shutdown_timeout_s=0.001)
        def g():
            return "world"

        @serve.deployment(graceful_shutdown_timeout_s=0.001)
        def h():
            # Always fails so that the error counter gets a sample.
            return 1 / 0

        # One application per deployment, each under its own route prefix.
        serve.run(f.bind(), name="app1", route_prefix="/app1")
        serve.run(g.bind(), name="app2", route_prefix="/app2")
        serve.run(h.bind(), name="app3", route_prefix="/app3")

        for url, expected_body in (
            ("http://127.0.0.1:8000/app1", "hello"),
            ("http://127.0.0.1:8000/app2", "world"),
        ):
            reply = httpx.get(url)
            assert reply.status_code == 200
            assert reply.text == expected_body
        failing_reply = httpx.get("http://127.0.0.1:8000/app3")
        assert failing_reply.status_code == 500

        def latency_reported_for_all_deployments():
            latency_series = get_metric_dictionaries(
                "serve_deployment_processing_latency_ms_sum"
            )
            return len(latency_series) == 3

        wait_for_condition(latency_reported_for_all_deployments, timeout=40)

        def wait_for_route_and_name(
            metric_name: str,
            deployment_name: str,
            app_name: str,
            route: str,
            timeout: float = 5,
        ):
            """Block until the deployment's metric has the app name and route."""

            def tags_match():
                samples = get_metric_dictionaries(metric_name)
                routes, app_names = self._generate_metrics_summary(samples)
                assert app_names[deployment_name] == app_name
                assert routes[deployment_name] == {route}
                return True

            wait_for_condition(tags_match, timeout=timeout)

        # Replica-level counters: requests for f and g, errors for h.
        for counter_name, deployment, application, prefix in (
            ("serve_deployment_request_counter", "f", "app1", "/app1"),
            ("serve_deployment_request_counter", "g", "app2", "/app2"),
            ("serve_deployment_error_counter", "h", "app3", "/app3"),
        ):
            wait_for_route_and_name(counter_name, deployment, application, prefix)

        # HTTP proxy request count and latency must cover all three routes.
        for proxy_metric in [
            "serve_num_http_requests",
            "serve_http_request_latency_ms_sum",
        ]:
            proxy_samples = get_metric_dictionaries(proxy_metric)
            seen_routes = {sample["route"] for sample in proxy_samples}
            assert seen_routes == {
                "/app1",
                "/app2",
                "/app3",
            }

        # Handle, router and replica latency metrics carry the same context.
        expected_route = {"f": "/app1", "g": "/app2", "h": "/app3"}
        expected_app = {"f": "app1", "g": "app2", "h": "app3"}
        for context_metric in [
            "serve_handle_request_counter",
            "serve_num_router_requests",
            "serve_deployment_processing_latency_ms_sum",
        ]:
            routes, app_names = self._generate_metrics_summary(
                get_metric_dictionaries(context_metric)
            )
            msg = f"Incorrect metrics for {context_metric}"
            for deployment, prefix in expected_route.items():
                assert routes[deployment] == {prefix}, msg
            for deployment, application in expected_app.items():
                assert app_names[deployment] == application, msg

    def test_request_context_pass_for_grpc_proxy(self, metrics_start_shutdown):
        """Verify that the gRPC proxy forwards route and app name to metrics.

        Three applications are deployed (the third one always raises) and
        called once over gRPC. For gRPC requests this test expects the
        "route" tag to equal the application name.
        """

        @serve.deployment(graceful_shutdown_timeout_s=0.001)
        class H:
            def __call__(self, *args, **kwargs):
                # Always fails so that the error counter gets a sample.
                return 1 / 0

        h = H.bind()
        first_app, first_deployment = "app1", "grpc-deployment"
        second_app, second_deployment = "app2", "grpc-deployment-model-composition"
        third_app, third_deployment = "app3", "H"
        serve.run(g, name=first_app, route_prefix="/app1")
        serve.run(g2, name=second_app, route_prefix="/app2")
        serve.run(h, name=third_app, route_prefix="/app3")

        channel = grpc.insecure_channel("localhost:9000")
        ping_grpc_call_method(channel, first_app)
        ping_fruit_stand(channel, second_app)
        with pytest.raises(grpc.RpcError):
            ping_grpc_call_method(channel, third_app)

        # app1 has 1 deployment, app2 has 3 deployments, and app3 has 1 deployment.
        def latency_reported_for_all_deployments():
            latency_series = get_metric_dictionaries(
                "serve_deployment_processing_latency_ms_sum"
            )
            return len(latency_series) == 5

        wait_for_condition(latency_reported_for_all_deployments, timeout=40)

        def wait_for_route_and_name(
            _metric_name: str,
            deployment_name: str,
            app_name: str,
            route: str,
            timeout: float = 5,
        ):
            """Block until the deployment's metric has the app name and route."""

            def tags_match():
                samples = get_metric_dictionaries(_metric_name)
                routes, app_names = self._generate_metrics_summary(samples)
                assert app_names[deployment_name] == app_name
                assert routes[deployment_name] == {route}
                return True

            wait_for_condition(tags_match, timeout=timeout)

        # Replica-level counters; the application name doubles as the route.
        for counter_name, deployment, application in (
            ("serve_deployment_request_counter", first_deployment, first_app),
            ("serve_deployment_request_counter", second_deployment, second_app),
            ("serve_deployment_error_counter", third_deployment, third_app),
        ):
            wait_for_route_and_name(counter_name, deployment, application, application)

        # gRPC proxy request count and latency must cover all three apps.
        for proxy_metric in [
            "serve_num_grpc_requests",
            "serve_grpc_request_latency_ms_sum",
        ]:
            proxy_samples = get_metric_dictionaries(proxy_metric)
            seen_routes = {sample["route"] for sample in proxy_samples}
            assert seen_routes == {
                "app1",
                "app2",
                "app3",
            }

        # Handle, router and replica latency metrics carry the same context.
        expected_context = (
            (first_deployment, "app1"),
            (second_deployment, "app2"),
            (third_deployment, "app3"),
        )
        for context_metric in [
            "serve_handle_request_counter",
            "serve_num_router_requests",
            "serve_deployment_processing_latency_ms_sum",
        ]:
            routes, app_names = self._generate_metrics_summary(
                get_metric_dictionaries(context_metric)
            )
            msg = f"Incorrect metrics for {context_metric}"
            for deployment, application in expected_context:
                assert routes[deployment] == {application}, msg
            for deployment, application in expected_context:
                assert app_names[deployment] == application, msg

    def test_request_context_pass_for_handle_passing(self, metrics_start_shutdown):
        """Verify that the request context follows calls made through handles.

        A FastAPI ingress deployment ``G`` exposes two endpoints, each of
        which calls a different downstream deployment. The downstream
        request counters must be tagged with the route of the ingress
        endpoint that triggered them.
        """

        @serve.deployment
        def g1():
            return "ok1"

        @serve.deployment
        def g2():
            return "ok2"

        app = FastAPI()

        @serve.deployment
        @serve.ingress(app)
        class G:
            def __init__(self, handle1: DeploymentHandle, handle2: DeploymentHandle):
                self.handle1 = handle1
                self.handle2 = handle2

            @app.get("/api")
            async def app1(self):
                downstream_result = await self.handle1.remote()
                return downstream_result

            @app.get("/api2")
            async def app2(self):
                downstream_result = await self.handle2.remote()
                return downstream_result

        serve.run(G.bind(g1.bind(), g2.bind()), name="app")
        app_url = get_application_url("HTTP", "app")
        first_reply = httpx.get(f"{app_url}/api")
        assert first_reply.text == '"ok1"'
        second_reply = httpx.get(f"{app_url}/api2")
        assert second_reply.text == '"ok2"'

        # Expected request-counter series (4 in total):
        #   G:  one for route /api and one for route /api2
        #   g1: one for route /api
        #   g2: one for route /api2
        def all_request_series_reported():
            request_series = get_metric_dictionaries("serve_deployment_request_counter")
            return len(request_series) == 4

        wait_for_condition(all_request_series_reported, timeout=40)

        routes_by_deployment, app_by_deployment = self._generate_metrics_summary(
            get_metric_dictionaries("serve_deployment_request_counter")
        )
        expected_routes = {
            "G": {"/api", "/api2"},
            "g1": {"/api"},
            "g2": {"/api2"},
        }
        for deployment, routes in expected_routes.items():
            assert routes_by_deployment[deployment] == routes
        for deployment in ("G", "g1", "g2"):
            assert app_by_deployment[deployment] == "app"

    @pytest.mark.parametrize("route_prefix", ["", "/prefix"])
    def test_fastapi_route_metrics(self, metrics_start_shutdown, route_prefix: str):
        """Check that the "route" tag uses the FastAPI route pattern.

        The tag is expected to be the application route prefix followed by
        the matched FastAPI route template (e.g. "/api2/{user_id}"), not the
        concrete request path.

        Args:
            route_prefix: Prefix to deploy under; an empty string means the
                application is run without an explicit ``route_prefix``.
        """
        app = FastAPI()

        @serve.deployment
        @serve.ingress(app)
        class A:
            @app.get("/api")
            def route1(self):
                return "ok1"

            @app.get("/api2/{user_id}")
            def route2(self):
                return "ok2"

        # Only pass route_prefix when one was requested, so the empty case
        # falls back to serve.run's own default.
        run_options = {"route_prefix": route_prefix} if route_prefix else {}
        serve.run(A.bind(), **run_options)

        base_url = get_application_url("HTTP")
        plain_reply = httpx.get(f"{base_url}/api")
        assert plain_reply.text == '"ok1"'
        templated_reply = httpx.get(f"{base_url}/api2/abc123")
        assert templated_reply.text == '"ok2"'

        def both_routes_reported():
            request_series = get_metric_dictionaries("serve_deployment_request_counter")
            return len(request_series) == 2

        wait_for_condition(both_routes_reported, timeout=40)

        routes_by_deployment, _ = self._generate_metrics_summary(
            get_metric_dictionaries("serve_deployment_request_counter")
        )
        expected_routes = {
            route_prefix + template for template in ("/api", "/api2/{user_id}")
        }
        assert routes_by_deployment["A"] == expected_routes

    def test_customer_metrics_with_context(self, metrics_start_shutdown):
        """Check that user-defined metrics are tagged with the request context.

        A deployment records a counter, a histogram and a gauge on every
        request. After a single HTTP request each metric must expose exactly
        one series carrying its static and runtime tags plus the replica,
        deployment, application and route of that request.
        """

        @serve.deployment
        class Model:
            def __init__(self):
                self.counter = Counter(
                    "my_counter",
                    description="my counter metrics",
                    tag_keys=(
                        "my_static_tag",
                        "my_runtime_tag",
                        "route",
                    ),
                )
                self.counter.set_default_tags({"my_static_tag": "static_value"})
                self.histogram = Histogram(
                    "my_histogram",
                    description=("my histogram "),
                    boundaries=DEFAULT_LATENCY_BUCKET_MS,
                    tag_keys=(
                        "my_static_tag",
                        "my_runtime_tag",
                        "route",
                    ),
                )
                self.histogram.set_default_tags({"my_static_tag": "static_value"})
                self.gauge = Gauge(
                    "my_gauge",
                    description=("my_gauge"),
                    tag_keys=(
                        "my_static_tag",
                        "my_runtime_tag",
                        "route",
                    ),
                )
                self.gauge.set_default_tags({"my_static_tag": "static_value"})

            def __call__(self):
                self.counter.inc(tags={"my_runtime_tag": "100"})
                self.histogram.observe(200, tags={"my_runtime_tag": "200"})
                self.gauge.set(300, tags={"my_runtime_tag": "300"})
                # Return the identifiers the test needs to predict the tags.
                return [
                    # NOTE(zcin): this is to match the current implementation in
                    # Serve's _add_serve_metric_default_tags().
                    ray.serve.context._INTERNAL_REPLICA_CONTEXT.deployment,
                    ray.serve.context._INTERNAL_REPLICA_CONTEXT.replica_id.unique_id,
                ]

        serve.run(Model.bind(), name="app", route_prefix="/app")
        http_url = get_application_url("HTTP", "app")
        resp = httpx.get(http_url)
        deployment_name, replica_id = resp.json()
        wait_for_condition(
            lambda: len(get_metric_dictionaries("my_gauge")) == 1,
            timeout=40,
        )

        def expected_tags(runtime_tag: str) -> Dict[str, str]:
            """Build the tags expected on a metric with the given runtime tag."""
            return {
                "my_static_tag": "static_value",
                "my_runtime_tag": runtime_tag,
                "replica": replica_id,
                "deployment": deployment_name,
                "application": "app",
                "route": "/app",
            }

        counter_metrics = get_metric_dictionaries("my_counter")
        assert len(counter_metrics) == 1
        self.verify_metrics(counter_metrics[0], expected_tags("100"))

        gauge_metrics = get_metric_dictionaries("my_gauge")
        # Check the gauge series itself; this used to re-check the counter
        # list, which left the gauge count unverified.
        assert len(gauge_metrics) == 1
        self.verify_metrics(gauge_metrics[0], expected_tags("300"))

        histogram_metrics = get_metric_dictionaries("my_histogram_sum")
        assert len(histogram_metrics) == 1
        self.verify_metrics(histogram_metrics[0], expected_tags("200"))

    @pytest.mark.parametrize("use_actor", [False, True])
    def test_serve_metrics_outside_serve(self, use_actor, metrics_start_shutdown):
        """Make sure ray.serve.metrics work in a plain Ray actor or task.

        A Serve deployment delegates each request to either a Ray actor or a
        Ray task that records a counter, a histogram and a gauge. Each metric
        must expose exactly one series with the static and runtime tags.

        Args:
            use_actor: If True the metrics are created and recorded inside a
                Ray actor; otherwise they are created in the test body and
                recorded inside a Ray task.
        """
        if use_actor:

            @ray.remote
            class MyActor:
                def __init__(self):
                    self.counter = Counter(
                        "my_counter",
                        description="my counter metrics",
                        tag_keys=(
                            "my_static_tag",
                            "my_runtime_tag",
                        ),
                    )
                    self.counter.set_default_tags({"my_static_tag": "static_value"})
                    self.histogram = Histogram(
                        "my_histogram",
                        description=("my histogram "),
                        boundaries=DEFAULT_LATENCY_BUCKET_MS,
                        tag_keys=(
                            "my_static_tag",
                            "my_runtime_tag",
                        ),
                    )
                    self.histogram.set_default_tags({"my_static_tag": "static_value"})
                    self.gauge = Gauge(
                        "my_gauge",
                        description=("my_gauge"),
                        tag_keys=(
                            "my_static_tag",
                            "my_runtime_tag",
                        ),
                    )
                    self.gauge.set_default_tags({"my_static_tag": "static_value"})

                def test(self):
                    self.counter.inc(tags={"my_runtime_tag": "100"})
                    self.histogram.observe(200, tags={"my_runtime_tag": "200"})
                    self.gauge.set(300, tags={"my_runtime_tag": "300"})
                    return "hello"

        else:
            # The metric objects are created here and captured by the remote
            # task below through its closure.
            counter = Counter(
                "my_counter",
                description="my counter metrics",
                tag_keys=(
                    "my_static_tag",
                    "my_runtime_tag",
                ),
            )
            histogram = Histogram(
                "my_histogram",
                description=("my histogram "),
                boundaries=DEFAULT_LATENCY_BUCKET_MS,
                tag_keys=(
                    "my_static_tag",
                    "my_runtime_tag",
                ),
            )
            gauge = Gauge(
                "my_gauge",
                description=("my_gauge"),
                tag_keys=(
                    "my_static_tag",
                    "my_runtime_tag",
                ),
            )

            @ray.remote
            def fn():
                counter.set_default_tags({"my_static_tag": "static_value"})
                histogram.set_default_tags({"my_static_tag": "static_value"})
                gauge.set_default_tags({"my_static_tag": "static_value"})
                counter.inc(tags={"my_runtime_tag": "100"})
                histogram.observe(200, tags={"my_runtime_tag": "200"})
                gauge.set(300, tags={"my_runtime_tag": "300"})
                return "hello"

        @serve.deployment
        class Model:
            def __init__(self):
                if use_actor:
                    self.my_actor = MyActor.remote()

            async def __call__(self):
                if use_actor:
                    return await self.my_actor.test.remote()
                else:
                    return await fn.remote()

        serve.run(Model.bind(), name="app", route_prefix="/app")
        http_url = get_application_url("HTTP", "app")
        resp = httpx.get(http_url)
        assert resp.text == "hello"
        wait_for_condition(
            lambda: len(get_metric_dictionaries("my_gauge")) == 1,
            timeout=40,
        )

        counter_metrics = get_metric_dictionaries("my_counter")
        assert len(counter_metrics) == 1
        expected_metrics = {
            "my_static_tag": "static_value",
            "my_runtime_tag": "100",
        }
        self.verify_metrics(counter_metrics[0], expected_metrics)

        gauge_metrics = get_metric_dictionaries("my_gauge")
        # Check the gauge series itself; this used to re-check the counter
        # list, which left the gauge count unverified.
        assert len(gauge_metrics) == 1
        expected_metrics = {
            "my_static_tag": "static_value",
            "my_runtime_tag": "300",
        }
        self.verify_metrics(gauge_metrics[0], expected_metrics)

        histogram_metrics = get_metric_dictionaries("my_histogram_sum")
        assert len(histogram_metrics) == 1
        expected_metrics = {
            "my_static_tag": "static_value",
            "my_runtime_tag": "200",
        }
        self.verify_metrics(histogram_metrics[0], expected_metrics)


class TestHandleMetrics:
    def test_queued_queries_basic(self, metrics_start_shutdown):
        """Check the queued-queries gauge for requests sent from one actor.

        With ``max_ongoing_requests=1`` the first request is assigned to the
        replica and every later one is expected to be counted as queued until
        the signal is released.
        """
        signal_actor = SignalActor.options(name="signal123").remote()
        serve.run(WaitForSignal.options(max_ongoing_requests=1).bind(), name="app1")

        caller = CallActor.remote("WaitForSignal", "app1")
        # First call should get assigned to a replica
        caller.call.remote()

        # Each further request should raise the queued count by one.
        for expected_queued in range(1, 6):
            caller.call.remote()
            wait_for_condition(
                check_sum_metric_eq,
                metric_name="ray_serve_deployment_queued_queries",
                tags={"application": "app1"},
                expected=expected_queued,
            )

        # Release signal; the queue is expected to drain completely.
        ray.get(signal_actor.send.remote())
        wait_for_condition(
            check_sum_metric_eq,
            metric_name="ray_serve_deployment_queued_queries",
            tags={"application": "app1", "deployment": "WaitForSignal"},
            expected=0,
        )

    def test_queued_queries_multiple_handles(self, metrics_start_shutdown):
        """Check the queued-queries gauge when each request uses its own handle.

        Every request is issued from a separate Ray task (and therefore a
        separate ``DeploymentHandle``). The summed gauge is expected to be 0
        after the first request and to grow by one for each later request.
        """
        signal_actor = SignalActor.options(name="signal123").remote()
        serve.run(WaitForSignal.options(max_ongoing_requests=1).bind(), name="app1")

        def wait_for_queued(expected_queued: int):
            """Wait until the summed queued-queries gauge equals the value."""
            wait_for_condition(
                check_sum_metric_eq,
                metric_name="ray_serve_deployment_queued_queries",
                tags={"application": "app1", "deployment": "WaitForSignal"},
                expected=expected_queued,
            )

        # The first request is not counted as queued (expected 0); the second
        # and third should stay queued (expected 1, then 2).
        for expected_queued in range(3):
            call.remote("WaitForSignal", "app1")
            wait_for_queued(expected_queued)

        # Release signal; the queue is expected to drain completely.
        ray.get(signal_actor.send.remote())
        wait_for_queued(0)

    def test_queued_queries_disconnected(self, metrics_start_shutdown):
        """Check that disconnected queued queries are tracked correctly.

        One HTTP request occupies the only replica slot, three more are sent
        behind it, and then every client task is force-cancelled. The
        queued-query, ongoing-HTTP-request and scheduling-task metrics are
        expected to rise accordingly and to return to zero afterwards.
        """

        blocker = SignalActor.remote()

        def wait_for_idle_metric(metric: str):
            """Wait until ``metric`` reads 0 for the current Ray session."""
            wait_for_condition(
                check_metric_float_eq,
                timeout=15,
                metric=metric,
                # Router is eagerly created on HTTP proxy, so there are metrics
                # emitted from proxy router
                expected=0,
                # TODO(zcin): this tag shouldn't be necessary, there shouldn't be
                # a mix of metrics from new and old sessions.
                expected_tags={
                    "SessionName": ray._private.worker.global_worker.node.session_name
                },
            )

        def wait_for_metric_sum(metric_name: str, expected: int):
            """Wait until the summed value of ``metric_name`` equals ``expected``."""
            wait_for_condition(
                check_sum_metric_eq,
                timeout=15,
                metric_name=metric_name,
                expected=expected,
            )

        @serve.deployment(
            max_ongoing_requests=1,
        )
        async def hang_on_first_request():
            await blocker.wait.remote()

        serve.run(hang_on_first_request.bind())

        print("Deployed hang_on_first_request deployment.")

        wait_for_idle_metric("ray_serve_num_scheduling_tasks")
        print("ray_serve_num_scheduling_tasks updated successfully.")
        # NOTE(review): unlike the other checks, this metric name has no
        # "ray_" prefix -- confirm that this is intentional.
        wait_for_idle_metric("serve_num_scheduling_tasks_in_backoff")
        print("serve_num_scheduling_tasks_in_backoff updated successfully.")

        @ray.remote(num_cpus=0)
        def do_request():
            r = httpx.get("http://localhost:8000/", timeout=10)
            r.raise_for_status()
            return r

        def first_request_is_waiting():
            return ray.get(blocker.cur_num_waiters.remote()) == 1

        # Make a request to block the deployment from accepting other requests.
        request_refs = [do_request.remote()]
        wait_for_condition(first_request_is_waiting, timeout=10)

        print("First request is executing.")
        wait_for_metric_sum("ray_serve_num_ongoing_http_requests", 1)
        print("ray_serve_num_ongoing_http_requests updated successfully.")

        num_queued_requests = 3
        for _ in range(num_queued_requests):
            request_refs.append(do_request.remote())
        print(f"{num_queued_requests} more requests now queued.")

        # First request should be processing. All others should be queued.
        wait_for_metric_sum("ray_serve_deployment_queued_queries", num_queued_requests)
        print("ray_serve_deployment_queued_queries updated successfully.")
        wait_for_metric_sum(
            "ray_serve_num_ongoing_http_requests", num_queued_requests + 1
        )
        print("ray_serve_num_ongoing_http_requests updated successfully.")

        # There should be 2 scheduling tasks (which is the max, since
        # 2 = 2 * 1 replica) that are attempting to schedule the hanging requests.
        wait_for_metric_sum("ray_serve_num_scheduling_tasks", 2)
        print("ray_serve_num_scheduling_tasks updated successfully.")
        wait_for_metric_sum("ray_serve_num_scheduling_tasks_in_backoff", 2)
        print("serve_num_scheduling_tasks_in_backoff updated successfully.")

        # Disconnect all requests by cancelling the Ray tasks.
        for ref in request_refs:
            ray.cancel(ref, force=True)
        print("Cancelled all HTTP requests.")

        wait_for_metric_sum("ray_serve_deployment_queued_queries", 0)
        print("ray_serve_deployment_queued_queries updated successfully.")

        # Task should get cancelled.
        wait_for_metric_sum("ray_serve_num_ongoing_http_requests", 0)
        print("ray_serve_num_ongoing_http_requests updated successfully.")

        wait_for_metric_sum("ray_serve_num_scheduling_tasks", 0)
        print("ray_serve_num_scheduling_tasks updated successfully.")
        wait_for_metric_sum("ray_serve_num_scheduling_tasks_in_backoff", 0)
        print("serve_num_scheduling_tasks_in_backoff updated successfully.")

        # Unblock hanging request.
        ray.get(blocker.send.remote())

    def test_running_requests_gauge(self, metrics_start_shutdown):
        """Check the per-deployment gauge of requests running at replicas.

        A ``Router`` deployment forwards each request to one of two blocking
        deployments (``d1`` / ``d2``) chosen at random. After every request
        the summed gauge must match the number of requests sent to each
        deployment, and it must drop to 0 once the signal is released.
        """
        signal_actor = SignalActor.options(name="signal123").remote()

        downstream = [
            WaitForSignal.options(
                name=deployment_name,
                ray_actor_options={"num_cpus": 0},
                max_ongoing_requests=2,
                num_replicas=3,
            ).bind()
            for deployment_name in ("d1", "d2")
        ]
        router_app = Router.options(
            num_replicas=2, ray_actor_options={"num_cpus": 0}
        ).bind(downstream)
        serve.run(router_app, name="app1")

        def wait_for_running(deployment: str, expected: int):
            """Wait until the deployment's summed running gauge equals expected."""
            wait_for_condition(
                check_sum_metric_eq,
                metric_name="ray_serve_num_ongoing_requests_at_replicas",
                tags={"application": "app1", "deployment": deployment},
                expected=expected,
            )

        # Number of requests routed to d1 (key 1) and d2 (key 2) so far.
        requests_sent = {1: 0, 2: 0}
        for total_sent in range(1, 6):
            index = random.choice([1, 2])
            print(f"Sending request to d{index}")
            call.remote("Router", "app1", index)
            requests_sent[index] += 1

            wait_for_running("d1", requests_sent[1])
            wait_for_running("d2", requests_sent[2])
            # Every request also stays in flight at the Router itself.
            wait_for_running("Router", total_sent)

        # Release signal, the number of running requests should drop to 0
        ray.get(signal_actor.send.remote())
        wait_for_condition(
            check_sum_metric_eq,
            metric_name="ray_serve_num_ongoing_requests_at_replicas",
            tags={"application": "app1"},
            expected=0,
        )


if __name__ == "__main__":
    # Run this file's tests verbosely without output capture and propagate
    # pytest's exit status to the shell.
    exit_code = pytest.main(["-v", "-s", __file__])
    sys.exit(exit_code)
